Skip to content

fix: move callCapturing to *ModelInner (callers use model.callCapturing)#32

Merged
solderzzc merged 1 commit into
mainfrom
feat/dflash-public-api-v2
Apr 24, 2026
Merged

fix: move callCapturing to *ModelInner (callers use model.callCapturing)#32
solderzzc merged 1 commit into
mainfrom
feat/dflash-public-api-v2

Conversation

@solderzzc
Copy link
Copy Markdown
Member

Hotfix on top of #31. The DFlash conformances in SwiftLM call model.callCapturing() where model is *ModelInner, so callCapturing must live on the inner class, not the outer wrapper. Verified builds clean locally.

The DFlashTargetModel conformances in SwiftLM call model.callCapturing()
where model is *ModelInner, not the outer *Model wrapper.
Copilot AI review requested due to automatic review settings April 24, 2026 04:20
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Moves the callCapturing API from the outer *Model wrappers onto the corresponding *ModelInner types so downstream DFlash conformances (which hold *ModelInner) can call model.callCapturing(...) directly.

Changes:

  • Add public callCapturing(_:cache:captureLayerIDs:) to LlamaModelInner, Qwen3ModelInner, and Qwen3MoEModelInner
  • Remove public callCapturing from LlamaModel, Qwen3Model, and Qwen3MoEModel

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
Libraries/MLXLLM/Models/Llama.swift Adds callCapturing to LlamaModelInner and removes it from LlamaModel.
Libraries/MLXLLM/Models/Qwen3.swift Adds callCapturing to Qwen3ModelInner and removes it from Qwen3Model.
Libraries/MLXLLM/Models/Qwen3MoE.swift Adds callCapturing to Qwen3MoEModelInner and removes it from Qwen3MoEModel.
Comments suppressed due to low confidence (1)

Libraries/MLXLLM/Models/Qwen3MoE.swift:289

  • callCapturing was removed from the public Qwen3MoEModel wrapper. If this API was already exposed in #31, consider reintroducing it here as a thin forwarding method to model.callCapturing(...) (optionally deprecated) to avoid breaking existing call sites that use the outer model type.
    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        var out = model(inputs, cache: cache)
        if let lmHead {
            out = lmHead(out)
        } else {
            out = model.embedTokens.asLinear(out)
        }
        return out
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var sanitizedWeights = weights

        if configuration.tieWordEmbeddings {
            sanitizedWeights["lm_head.weight"] = nil

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@@ -182,29 +203,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
}
}

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callCapturing was removed from the public LlamaModel wrapper. Since it was introduced as part of the public API in the prior PR, consider keeping a forwarding LlamaModel.callCapturing(...) that delegates to model.callCapturing(...) (optionally marking it deprecated) to avoid a breaking change for any clients already calling the wrapper type.

Suggested change
@available(*, deprecated, message: "Use the underlying model.callCapturing(...) instead.")
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> Any {
model.callCapturing(inputs, cache: cache)
}

Copilot uses AI. Check for mistakes.
@@ -208,29 +229,6 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
return out
}

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callCapturing was removed from the public Qwen3Model wrapper. To preserve source compatibility for users that may have started calling Qwen3Model.callCapturing(...) after #31, consider keeping a wrapper method that forwards to model.callCapturing(...) (and optionally deprecate it if you want callers to move to the inner type).

Suggested change
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> (MLXArray, [MLXArray]) {
var (out, captured) = model.callCapturing(inputs, cache: cache)
if let lmHead {
out = lmHead(out)
} else {
out = model.embedTokens.asLinear(out)
}
return (out, captured)
}

Copilot uses AI. Check for mistakes.
@solderzzc
Copy link
Copy Markdown
Member Author

Noted on the source compatibility concern. Since PR #31 was merged only minutes before this hotfix and callCapturing is an internal API consumed solely by DFlashTargetModel conformances in SwiftLM (no released binary has shipped against it), we're skipping the deprecated forwarder to keep the API surface clean. If external callers appear in future we can add it then.

@solderzzc solderzzc merged commit 694806d into main Apr 24, 2026
10 checks passed
@solderzzc solderzzc deleted the feat/dflash-public-api-v2 branch April 24, 2026 04:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants